// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2011-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_ITERATIVE_SOLVER_BASE_H
#define EIGEN_ITERATIVE_SOLVER_BASE_H

namespace Eigen {

namespace internal {

    /** \internal Compile-time test of whether \c MatrixType can be passed where a
      * \c Ref<const MatrixType> is expected.
      *
      * Uses the classic \c sizeof overload-resolution trick: the \c test overloads are only
      * declared and only named inside \c sizeof, so nothing is evaluated at run time.
      * The result is read from the size of the selected overload's return type
      * (\c yes and \c no have different sizes).
      */
    template <typename MatrixType> struct is_ref_compatible_impl
    {
    private:
        // Catch-all parameter type: constructible from any lvalue (const/volatile or not).
        // T0 is unused; it only makes the type depend on the tested matrix type.
        template <typename T0> struct any_conversion
        {
            template <typename T> any_conversion(const volatile T&);
            template <typename T> any_conversion(T&);
        };
        // Return-type tags, distinguishable only through sizeof.
        struct yes
        {
            int a[1];
        };
        struct no
        {
            int a[2];
        };

        // Intended to be selected when the argument converts to Ref<const T>: its second
        // parameter matches the literal 0 exactly, whereas the fallback receives it through
        // the C-style ellipsis, which ranks worst in overload resolution.
        template <typename T> static yes test(const Ref<const T>&, int);
        // Fallback, viable for any argument.
        template <typename T> static no test(any_conversion<T>, ...);

    public:
        // Only named inside the unevaluated sizeof operand below, so no definition is needed.
        static MatrixType ms_from;
        enum
        {
            value = sizeof(test<MatrixType>(ms_from, 0)) == sizeof(yes)
        };
    };

    template <typename MatrixType> struct is_ref_compatible
    {
        enum
        {
            value = is_ref_compatible_impl<typename remove_all<MatrixType>::type>::value
        };
    };

    template <typename MatrixType, bool MatrixFree = !internal::is_ref_compatible<MatrixType>::value> class generic_matrix_wrapper;

    // We have an explicit matrix at hand, compatible with Ref<>
    // We have an explicit matrix at hand, compatible with Ref<>: it is accessed through a
    // Ref<const MatrixType> member that is rebuilt in place whenever grab() is called.
    template <typename MatrixType> class generic_matrix_wrapper<MatrixType, false>
    {
    public:
        typedef Ref<const MatrixType> ActualMatrixType;

        template <int UpLo> struct ConstSelfAdjointViewReturnType
        {
            typedef typename ActualMatrixType::template ConstSelfAdjointViewReturnType<UpLo>::Type Type;
        };

        enum
        {
            MatrixFree = false
        };

        /** Binds the Ref<> member to an empty 0x0 placeholder, since it must refer to something. */
        generic_matrix_wrapper() : m_dummy(0, 0), m_matrix(m_dummy) {}

        /** Binds the Ref<> member directly to \a mat. */
        template <typename InputType> generic_matrix_wrapper(const InputType& mat) : m_matrix(mat) {}

        /** \returns the wrapped Ref<> to the current matrix. */
        const ActualMatrixType& matrix() const { return m_matrix; }

        /** Re-targets the wrapper to an arbitrary Eigen object \a mat. */
        template <typename MatrixDerived> void grab(const EigenBase<MatrixDerived>& mat) { rebind(mat.derived()); }

        /** Re-targets the wrapper to another Ref<>. Grabbing our own member is a no-op:
          * rebuilding it from itself would read an already destroyed object. */
        void grab(const Ref<const MatrixType>& mat)
        {
            if (&(mat.derived()) == &m_matrix)
                return;
            rebind(mat);
        }

    protected:
        MatrixType m_dummy;  // placeholder the default constructor binds m_matrix to
        ActualMatrixType m_matrix;

    private:
        // Re-targets the Ref<> member: explicit destruction followed by placement-new
        // construction from \a source.
        template <typename Source> void rebind(const Source& source)
        {
            m_matrix.~Ref<const MatrixType>();
            ::new (&m_matrix) Ref<const MatrixType>(source);
        }
    };

    // MatrixType is not compatible with Ref<> -> matrix-free wrapper
    // MatrixType is not compatible with Ref<> (matrix-free operator): the wrapper keeps only a
    // non-owning pointer, so the referenced operator must stay alive while it is in use.
    template <typename MatrixType> class generic_matrix_wrapper<MatrixType, true>
    {
    public:
        typedef MatrixType ActualMatrixType;

        template <int UpLo> struct ConstSelfAdjointViewReturnType
        {
            // No self-adjoint view type is formed; the operator type itself is used.
            typedef ActualMatrixType Type;
        };

        enum
        {
            MatrixFree = true
        };

        /** Nothing attached yet: matrix() must not be called before grab(). */
        generic_matrix_wrapper() : mp_matrix(0) {}

        /** Attaches \a mat by address (no copy is made). */
        generic_matrix_wrapper(const MatrixType& mat) : mp_matrix(&mat) {}

        /** \returns the currently attached operator. */
        const ActualMatrixType& matrix() const { return *mp_matrix; }

        /** Re-attaches the wrapper to \a mat by address (no copy is made). */
        void grab(const MatrixType& mat) { mp_matrix = &mat; }

    protected:
        const ActualMatrixType* mp_matrix;  // non-owning; null until something is attached
    };

}  // namespace internal

/** \ingroup IterativeLinearSolvers_Module
  * \brief Base class for linear iterative solvers
  *
  * CRTP base shared by the iterative solvers. It keeps track of the system matrix through
  * internal::generic_matrix_wrapper (so both Ref<>-compatible matrices and matrix-free
  * operators are handled), owns the preconditioner, and stores the common stopping
  * parameters (tolerance and maximal number of iterations).
  *
  * \c internal::traits<Derived> must define \c MatrixType and \c Preconditioner, and
  * \c Derived must implement \c _solve_vector_with_guess_impl(rhs, x), which solves for a
  * single column and updates \c m_info (and presumably \c m_iterations / \c m_error).
  *
  * \sa class SimplicialCholesky, DiagonalPreconditioner, IdentityPreconditioner
  */
template <typename Derived> class IterativeSolverBase : public SparseSolverBase<Derived>
{
protected:
    typedef SparseSolverBase<Derived> Base;
    using Base::m_isInitialized;

public:
    typedef typename internal::traits<Derived>::MatrixType MatrixType;
    typedef typename internal::traits<Derived>::Preconditioner Preconditioner;
    typedef typename MatrixType::Scalar Scalar;
    typedef typename MatrixType::StorageIndex StorageIndex;
    typedef typename MatrixType::RealScalar RealScalar;

    enum
    {
        ColsAtCompileTime = MatrixType::ColsAtCompileTime,
        MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime
    };

public:
    using Base::derived;

    /** Default constructor. */
    IterativeSolverBase() { init(); }

    /** Initialize the solver with matrix \a A for further \c Ax=b solving.
    *
    * This constructor is a shortcut for the default constructor followed
    * by a call to compute().
    *
    * \warning this class stores a reference to the matrix A as well as some
    * precomputed values that depend on it. Therefore, if \a A is changed
    * this class becomes invalid. Call compute() to update it with the new
    * matrix A, or modify a copy of A.
    */
    template <typename MatrixDerived> explicit IterativeSolverBase(const EigenBase<MatrixDerived>& A) : m_matrixWrapper(A.derived())
    {
        init();
        compute(matrix());
    }

    ~IterativeSolverBase() {}

    /** Initializes the iterative solver for the sparsity pattern of the matrix \a A for further solving \c Ax=b problems.
    *
    * Currently, this function mostly calls analyzePattern on the preconditioner. In the future
    * we might, for instance, implement column reordering for faster matrix vector products.
    * Any previous factorization is invalidated; call factorize() afterwards.
    */
    template <typename MatrixDerived> Derived& analyzePattern(const EigenBase<MatrixDerived>& A)
    {
        grab(A.derived());
        m_preconditioner.analyzePattern(matrix());
        m_isInitialized = true;
        m_analysisIsOk = true;
        // The preconditioner has only been analyzed for the new pattern, not factorized.
        m_factorizationIsOk = false;
        m_info = m_preconditioner.info();
        return derived();
    }

    /** Initializes the iterative solver with the numerical values of the matrix \a A for further solving \c Ax=b problems.
    *
    * Currently, this function mostly calls factorize on the preconditioner.
    *
    * \warning this class stores a reference to the matrix A as well as some
    * precomputed values that depend on it. Therefore, if \a A is changed
    * this class becomes invalid. Call compute() to update it with the new
    * matrix A, or modify a copy of A.
    */
    template <typename MatrixDerived> Derived& factorize(const EigenBase<MatrixDerived>& A)
    {
        eigen_assert(m_analysisIsOk && "You must first call analyzePattern()");
        grab(A.derived());
        m_preconditioner.factorize(matrix());
        m_factorizationIsOk = true;
        m_info = m_preconditioner.info();
        return derived();
    }

    /** Initializes the iterative solver with the matrix \a A for further solving \c Ax=b problems.
    *
    * Currently, this function mostly initializes/computes the preconditioner. In the future
    * we might, for instance, implement column reordering for faster matrix vector products.
    *
    * \warning this class stores a reference to the matrix A as well as some
    * precomputed values that depend on it. Therefore, if \a A is changed
    * this class becomes invalid. Call compute() to update it with the new
    * matrix A, or modify a copy of A.
    */
    template <typename MatrixDerived> Derived& compute(const EigenBase<MatrixDerived>& A)
    {
        grab(A.derived());
        m_preconditioner.compute(matrix());
        m_isInitialized = true;
        m_analysisIsOk = true;
        m_factorizationIsOk = true;
        m_info = m_preconditioner.info();
        return derived();
    }

    /** \internal */
    EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return matrix().rows(); }

    /** \internal */
    EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return matrix().cols(); }

    /** \returns the tolerance threshold used by the stopping criteria.
    * \sa setTolerance()
    */
    RealScalar tolerance() const { return m_tolerance; }

    /** Sets the tolerance threshold used by the stopping criteria.
    *
    * This value is used as an upper bound to the relative residual error: |Ax-b|/|b|.
    * The default value is the machine precision given by NumTraits<Scalar>::epsilon()
    */
    Derived& setTolerance(const RealScalar& tolerance)
    {
        m_tolerance = tolerance;
        return derived();
    }

    /** \returns a read-write reference to the preconditioner for custom configuration. */
    Preconditioner& preconditioner() { return m_preconditioner; }

    /** \returns a read-only reference to the preconditioner. */
    const Preconditioner& preconditioner() const { return m_preconditioner; }

    /** \returns the max number of iterations.
    * It is either the value set by setMaxIterations or, by default,
    * twice the number of columns of the matrix.
    */
    Index maxIterations() const { return (m_maxIterations < 0) ? 2 * matrix().cols() : m_maxIterations; }

    /** Sets the max number of iterations.
    * Default is twice the number of columns of the matrix.
    */
    Derived& setMaxIterations(Index maxIters)
    {
        m_maxIterations = maxIters;
        return derived();
    }

    /** \returns the number of iterations performed during the last solve (0 before any solve). */
    Index iterations() const
    {
        eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
        return m_iterations;
    }

    /** \returns the tolerance error reached during the last solve (0 before any solve).
    * It is a close approximation of the true relative residual error |Ax-b|/|b|.
    */
    RealScalar error() const
    {
        eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
        return m_error;
    }

    /** \returns the solution x of \f$ A x = b \f$ using the current decomposition of A
    * and \a x0 as an initial solution.
    *
    * \sa solve(), compute()
    */
    template <typename Rhs, typename Guess> inline const SolveWithGuess<Derived, Rhs, Guess> solveWithGuess(const MatrixBase<Rhs>& b, const Guess& x0) const
    {
        eigen_assert(m_isInitialized && "Solver is not initialized.");
        eigen_assert(derived().rows() == b.rows() && "solve(): invalid number of rows of the right hand side matrix b");
        return SolveWithGuess<Derived, Rhs, Guess>(derived(), b.derived(), x0);
    }

    /** \returns Success if the iterations converged, and NoConvergence otherwise.
    * After a multi-column solve, this is the most severe status over all columns.
    */
    ComputationInfo info() const
    {
        eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
        return m_info;
    }

    /** \internal Sparse destination: each column is solved in a dense temporary, using the
      * current content of \a aDest as initial guess, and the result is swapped in at the end.
      */
    template <typename Rhs, typename DestDerived> void _solve_with_guess_impl(const Rhs& b, SparseMatrixBase<DestDerived>& aDest) const
    {
        eigen_assert(rows() == b.rows());

        Index rhsCols = b.cols();
        Index size = b.rows();
        DestDerived& dest(aDest.derived());
        typedef typename DestDerived::Scalar DestScalar;
        Eigen::Matrix<DestScalar, Dynamic, 1> tb(size);
        Eigen::Matrix<DestScalar, Dynamic, 1> tx(cols());
        // We do not directly fill dest because sparse expressions have to be free of aliasing issue.
        // For non square least-square problems, b and dest might not have the same size whereas they might alias each-other.
        typename DestDerived::PlainObject tmp(cols(), rhsCols);
        ComputationInfo global_info = Success;
        for (Index k = 0; k < rhsCols; ++k)
        {
            tb = b.col(k);
            tx = dest.col(k);
            derived()._solve_vector_with_guess_impl(tb, tx);
            tmp.col(k) = tx.sparseView(0);

            // The call to _solve_vector_with_guess_impl overwrites m_info; accumulate the most
            // severe status so that a failure on an earlier column is not masked by a later one.
            global_info = mergeColumnInfo(global_info, m_info);
        }
        m_info = global_info;
        dest.swap(tmp);
    }

    /** \internal Dense multi-column destination: solves column by column in place, using the
      * current content of \a aDest as initial guess.
      */
    template <typename Rhs, typename DestDerived>
    typename internal::enable_if<Rhs::ColsAtCompileTime != 1 && DestDerived::ColsAtCompileTime != 1>::type
    _solve_with_guess_impl(const Rhs& b, MatrixBase<DestDerived>& aDest) const
    {
        eigen_assert(rows() == b.rows());

        Index rhsCols = b.cols();
        DestDerived& dest(aDest.derived());
        ComputationInfo global_info = Success;
        for (Index k = 0; k < rhsCols; ++k)
        {
            typename DestDerived::ColXpr xk(dest, k);
            typename Rhs::ConstColXpr bk(b, k);
            derived()._solve_vector_with_guess_impl(bk, xk);

            // The call to _solve_vector_with_guess_impl overwrites m_info; accumulate the most
            // severe status so that a failure on an earlier column is not masked by a later one.
            global_info = mergeColumnInfo(global_info, m_info);
        }
        m_info = global_info;
    }

    /** \internal Single-column case: forwarded directly to the derived solver. */
    template <typename Rhs, typename DestDerived>
    typename internal::enable_if<Rhs::ColsAtCompileTime == 1 || DestDerived::ColsAtCompileTime == 1>::type
    _solve_with_guess_impl(const Rhs& b, MatrixBase<DestDerived>& dest) const
    {
        derived()._solve_vector_with_guess_impl(b, dest.derived());
    }

    /** \internal default initial guess = 0 */
    template <typename Rhs, typename Dest> void _solve_impl(const Rhs& b, Dest& x) const
    {
        x.setZero();
        derived()._solve_with_guess_impl(b, x);
    }

protected:
    /** Resets the state flags and the stopping parameters to their defaults. */
    void init()
    {
        m_isInitialized = false;
        m_analysisIsOk = false;
        m_factorizationIsOk = false;
        m_maxIterations = -1;  // negative: maxIterations() falls back to 2 * cols()
        m_tolerance = NumTraits<Scalar>::epsilon();
        // Give the per-solve statistics defined values so that they are never read
        // uninitialized when queried before the first solve.
        m_iterations = 0;
        m_error = RealScalar(0);
        m_info = Success;
    }

    /** \internal Combines the status accumulated over previous columns with the status of the
      * column just solved, keeping the most severe one: NumericalIssue > NoConvergence > Success.
      * Other status values reported for a column do not affect the result.
      */
    static ComputationInfo mergeColumnInfo(ComputationInfo accumulated, ComputationInfo column)
    {
        if (accumulated == NumericalIssue || column == NumericalIssue)
            return NumericalIssue;
        if (accumulated == NoConvergence || column == NoConvergence)
            return NoConvergence;
        return accumulated;
    }

    typedef internal::generic_matrix_wrapper<MatrixType> MatrixWrapper;
    typedef typename MatrixWrapper::ActualMatrixType ActualMatrixType;

    const ActualMatrixType& matrix() const { return m_matrixWrapper.matrix(); }

    template <typename InputType> void grab(const InputType& A) { m_matrixWrapper.grab(A); }

    MatrixWrapper m_matrixWrapper;
    Preconditioner m_preconditioner;

    Index m_maxIterations;
    RealScalar m_tolerance;

    mutable RealScalar m_error;
    mutable Index m_iterations;
    mutable ComputationInfo m_info;
    mutable bool m_analysisIsOk, m_factorizationIsOk;
};

}  // end namespace Eigen

#endif  // EIGEN_ITERATIVE_SOLVER_BASE_H
